#!/usr/bin/python
#-*-coding:utf-8-*-

###############################################################
## Name       : auto_testbench
## Author     : xiaotu
## Time       : 2022-03-26 04:15:12
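## Description: Generates a SystemVerilog testbench skeleton (and an
##              optional verification package) from the ports and
##              parameters of a Verilog top module.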
###############################################################

import sys
import os
import re
import argparse
import math
import difflib

class Signal: #{{{
    """One signal (port or net) extracted from the RTL source.

    Attributes:
        name / type / port / width: values passed to the constructor.
        ctrl_sig: set to 1 once ``valid_en``/``ready_en`` recognises the
            name as a ``...valid`` / ``...ready`` signal.
        find_valid_sig / find_ready_sig: 1 once a partner signal has been
            assigned through ``set_valid_sig`` / ``set_ready_sig``.
        valid_sig / ready_sig: name of the associated valid/ready signal.
            They start out as the conventional partner name so reading them
            never raises AttributeError.
        bus_list: signals grouped under this signal via ``valid_get_bus``.
    """
    # Longest name / width string seen over all instances created so far.
    name_width = 1
    widh_width = 1  # spelling kept as-is: it is part of the class interface

    def __init__(self, name, type="wire", port="none", width = ""):
        self.name   = name
        self.type   = type
        self.port   = port
        self.width  = width
        self.ctrl_sig = 0  # not a control signal by default
        self.find_valid_sig = 0
        self.find_ready_sig = 0
        # Conventional partner names; replaced by org_*/set_* later on.
        self.valid_sig = self._partner_name(name, "ready", "valid")
        self.ready_sig = self._partner_name(name, "valid", "ready")
        self.bus_list = []
        Signal.name_width  = max(Signal.name_width, len(name))
        Signal.widh_width  = max(Signal.widh_width, len(width))

    @staticmethod
    def _partner_name(name, suffix, partner):
        """Swap a trailing ``suffix`` for ``partner`` or append ``_partner``."""
        if name.endswith(suffix):
            return name[:len(name) - len(suffix)] + partner
        return name + "_" + partner

    def __str__(self):
        text = "Class Signal\n"
        text = text + "    name   : %s\n" % self.name
        text = text + "    type   : %s\n" % self.type
        text = text + "    port   : %s\n" % self.port
        text = text + "    width  : %s\n" % self.width
        text = text + "    max name_width = %s\n" % Signal.name_width
        return text

    def valid_en(self):
        """Return 1 (and mark as control signal) if the name ends in 'valid'."""
        if re.search(r"valid$", self.name):
            self.ctrl_sig = 1
            return 1
        else:
            return 0

    def ready_en(self):
        """Return 1 (and mark as control signal) if the name ends in 'ready'."""
        if re.search(r"ready$", self.name):
            self.ctrl_sig = 1
            return 1
        else:
            return 0

    def org_valid_sig(self):
        """Store and return the conventional valid-signal name for this signal."""
        if self.ready_en() == 1:
            self.valid_sig = re.sub(r"ready$", "valid", self.name)
        else:
            self.valid_sig = self.name + "_valid"
        return self.valid_sig

    def org_ready_sig(self):
        """Store and return the conventional ready-signal name for this signal."""
        if self.valid_en() == 1:
            self.ready_sig = re.sub(r"valid$", "ready", self.name)
        else:
            self.ready_sig = self.name + "_ready"
        return self.ready_sig

    def set_valid_sig(self, sig):
        """Record ``sig`` as the matched valid signal."""
        self.valid_sig  = sig
        self.find_valid_sig = 1

    def set_ready_sig(self, sig):
        """Record ``sig`` as the matched ready signal."""
        self.ready_sig = sig
        self.find_ready_sig = 1

    def valid_get_bus(self, Sig):
        """Append ``Sig`` to the list of signals grouped under this one."""
        self.bus_list.append(Sig)

pass#}}}

def debug_print(*ln): #{{{
    """Print the positional arguments (as one tuple) when debugging is on.

    Reads the module-level ``debug`` flag, which ``input_args_proc`` sets
    from the ``-d`` option; nothing is printed while the flag is falsy.
    """
    if not debug:
        return
    print(ln)
pass #}}}

def write_list(list): #{{{
    """Print each element of ``list`` on a line of its own."""
    entries = iter(list)
    for entry in entries:
        print(entry)
pass #}}}

def input_args_proc(): #{{{
    """Parse the command line and publish the run-wide settings.

    Options:
        -o  open this script in gvim and exit
        -t  top module name (defaults to the base name of the -f path)
        -d  debug mode
        -f  RTL file or folder path (required unless -o is given)
        -v  also generate the verification code

    Sets the module globals ``debug``, ``ver``, ``pwd_dict`` (current
    directory with a trailing slash) and ``pwd_path`` (the -f path anchored
    at the current directory).

    Returns:
        (top, path): top module name and the -f path as given.

    Exits with the argparse usage error (status 2) when -f is missing,
    instead of failing later with a TypeError on ``None``.
    """
    global debug
    global ver
    global pwd_path
    global pwd_dict
    parser = argparse.ArgumentParser(description="argparse info")
    parser.add_argument('-o', action='store_true', default=False, help='open this script')
    parser.add_argument('-t', default=False, help='top module name')
    parser.add_argument('-d', action='store_true', default=False, help='debug mode')
    parser.add_argument('-f', help='file/folder path')
    parser.add_argument('-v', action='store_true', default=False, help='auto verification code gen')
    result = parser.parse_args()
    debug = result.d
    if result.o:
        os.system("gvim %s" % __file__)
        sys.exit(0)

    path = result.f
    if path is None:
        parser.error("argument -f is required (RTL file or folder path)")

    if result.t:
        top = result.t
    else:
        # normpath drops a trailing slash so "rtl/" still yields "rtl".
        top = os.path.splitext(os.path.basename(os.path.normpath(path)))[0]

    ver  = result.v
    pwd_dict = os.getcwd() + "/"
    # os.path.join leaves an absolute -f path untouched instead of gluing
    # the current directory in front of it.
    pwd_path = os.path.join(os.getcwd(), path)
    debug_print(top, path, pwd_path)
    return top, path
pass #}}}

def find_top_rtl(top, path): #{{{
    """Return the path of the Verilog file that holds the top module.

    ``path`` is used directly when it is an existing ``.v`` file; otherwise
    it is treated as a folder and ``<path>/<top>.v`` is looked up.

    Prints a message and exits with status 1 when no such file exists, so
    calling scripts can detect the failure.
    """
    ext = os.path.splitext(path)[1]
    if ext == ".v" and os.path.isfile(path):
        file_path = path
    else:
        file_path = os.path.join(path, top + ".v")

    if os.path.isfile(file_path):
        return file_path

    print("no input verilog file : %s" % file_path)
    sys.exit(1)
pass #}}}

def head_tail_split(handle, head = "\n", tail = "\n", mode = 0): #{{{
    '''
    Return the lines of ``handle`` lying between a line matching the regex
    ``head`` and the next line matching the regex ``tail``.

    mode selects whether the delimiting lines themselves are returned:
    mode = 0  neither head nor tail line
    mode = 1  both head and tail line
    mode = 2  head line but not tail line
    mode = 3  tail line but not head line

    A line matching both patterns always counts as a tail; it counts as a
    head only when no section is currently open. The head/tail markers are
    recomputed for every line, so a marker from the previous line can no
    longer leak into the decision for the current one.
    '''
    ret = []
    inside = False  # True while a section has been opened and not closed
    for line in handle:
        is_head = re.search(head, line) is not None
        is_tail = re.search(tail, line) is not None
        head_hit = False
        tail_hit = False

        if is_head and is_tail:
            # Closes whatever is open; opens (and closes) a one-line
            # section only if nothing was open before.
            tail_hit = True
            head_hit = not inside
            inside = False
        elif is_head:
            head_hit = not inside
            inside = True
        elif is_tail:
            # A tail outside any section is just an ordinary line.
            tail_hit = inside
            inside = False

        if mode == 0:
            keep = inside and not head_hit and not tail_hit
        elif mode == 1:
            keep = inside or head_hit or tail_hit
        elif mode == 2:
            keep = (inside or head_hit) and not tail_hit
        else:
            keep = (inside or tail_hit) and not head_hit

        if keep:
            ret.append(line)
    return ret
pass #}}}

def del_note_code(handle):#{{{
    """Return the lines of ``handle`` with Verilog comments removed.

    ``// ...`` comments are cut off the end of a line. ``/* ... */``
    comments are removed wherever they start, including inline ones and
    ones spanning several lines; code before or after them on the same
    line is kept. A line that consists only of block-comment text is
    dropped altogether, while a line holding only a ``//`` comment is kept
    as an empty line.

    Note: comment markers inside string literals are not recognised as
    such and are stripped like real comments.
    """
    opener = re.compile(r"/\*|//")
    ret_handle = []
    in_block = False  # True while inside an unterminated /* ... */
    for line in handle:
        saw_block = in_block
        pieces = []
        pos = 0
        while True:
            if in_block:
                close = line.find("*/", pos)
                if close < 0:
                    break  # comment continues on the next line
                in_block = False
                pos = close + 2
                continue
            hit = opener.search(line, pos)
            if hit is None:
                pieces.append(line[pos:])
                break
            pieces.append(line[pos:hit.start()])
            if hit.group() == "//":
                break  # rest of the line is a line comment
            in_block = True
            saw_block = True
            pos = hit.end()

        code = "".join(pieces)
        if saw_block and not code.strip():
            continue  # nothing but block-comment text on this line
        if line.endswith("\n") and not code.endswith("\n"):
            code += "\n"  # the newline was swallowed together with a comment
        ret_handle.append(code)

    return ret_handle
#}}}

def sys_rtl_sig_para(top, path): #{{{
    """Parse module ``top`` in the Verilog file ``path``.

    (Re)initialises and fills these module globals:
        sig_list     Signal objects for every input/output port, in file order
        signal_hash  port name -> Signal
        valid_list   names of ports ending in "valid"
        ready_list   names of ports ending in "ready"
        para_list    parameter names, in file order
        para_hash    parameter name -> value text

    The module header is matched on the exact module name (escaped, with a
    word boundary) so that e.g. ``top`` does not also select ``top_wrap``,
    and declaration keywords must be whole words so that an instance such
    as ``input_stage u0(`` is not mistaken for a port.
    """
    global sig_list
    global valid_list
    global ready_list
    global para_list
    global para_hash
    global signal_hash
    sig_list  = []
    valid_list= []
    ready_list= []
    para_list = []
    para_hash = {}
    signal_hash = {}

    decl_re = re.compile(r"^\s*(input|output|wire|reg)\b(\s+wire\b|\s+reg\b)?\s*(\[.*\])?\s*([\s,\w]+)\s*")
    para_re = re.compile(r"^\s*(parameter)\s+(\w+)\s*=\s*([\$\(\)\w\']+)")
    module_head = r"^module\s+" + re.escape(top) + r"\b"

    with open (path, "r") as rtl:
        handle = rtl.readlines()
    handle = del_note_code(handle)
    rtl_line = head_tail_split(handle, module_head, r"^endmodule", 1)

    note_flag = 0
    for line in rtl_line:
        decl = decl_re.search(line)
        para = para_re.search(line)
        sig_type = "wire"
        port = "none"
        width = ""

        # Skip anything that still sits inside a /* ... */ block.
        if re.search(r"^\/\*", line):
            note_flag = 1
        elif re.search(r"\*\/$", line):
            note_flag = 0
            continue

        if note_flag == 1:
            continue

        if decl:
            if decl.group(3):
                width = decl.group(3)
            if decl.group(1) in ("input", "output"):
                port = decl.group(1).strip()
            else:
                sig_type = decl.group(1).strip()
            if decl.group(2):
                sig_type = decl.group(2).strip()
            for sig in decl.group(4).split(","):
                name = sig.strip()
                if name != "":
                    s = Signal(name, sig_type, port, width)
                    # Only ports are recorded; internal wires/regs are not.
                    if port == "input" or port == "output":
                        if re.search(r"valid$", name):
                            valid_list.append(name)
                        elif re.search(r"ready$", name):
                            ready_list.append(name)
                        sig_list.append(s)
                        signal_hash[s.name] = s
        if para:
            param = para.group(2)
            value = para.group(3)
            para_hash[param] = value
            para_list.append(param)
pass #}}}

def gen_inst(top): #{{{
    """Return the instantiation of module ``top`` as testbench text.

    Every entry of the global ``para_list`` and ``sig_list`` is connected
    by name (``.x(x)``); the parameter section is only emitted when there
    are parameters.
    """
    pieces = ["//-------------------------------------{{{rtl inst\n"]
    if para_list:
        pieces.append("%s #(\n" % top)
        pieces.append(",\n".join("    .%s(%s)" % (para, para) for para in para_list))
        pieces.append(") \nu_%s(\n" % top)
    else:
        pieces.append("%s u_%s(\n" % (top, top))

    pieces.append(",\n".join("    .%s(%s)" % (sig.name, sig.name) for sig in sig_list))
    pieces.append("\n);\n")
    pieces.append("//-------------------------------------}}}\n\n")
    return "".join(pieces)
pass #}}}

def gen_para_dec(): #{{{
    """Return the testbench's parameter declarations.

    One ``parameter <name> = <value>;`` line per entry of the global
    ``para_list``, with values taken from ``para_hash``.
    """
    declarations = "".join(
        "parameter %s = %s;\n" % (para, para_hash[para]) for para in para_list
    )
    return (
        "//-------------------------------------{{{parameter declare\n"
        + declarations
        + "//-------------------------------------}}}\n"
    )
pass #}}}

def gen_sig_dec(): #{{{
    """Return the testbench's signal declarations.

    One ``logic <width> <name>;`` line per entry of the global ``sig_list``.
    """
    declarations = "".join(
        "logic %s %s;\n" % (sig.width, sig.name) for sig in sig_list
    )
    return (
        "//-------------------------------------{{{signal declare\n"
        + declarations
        + "//-------------------------------------}}}\n"
    )
pass #}}}

def gen_force_dec(): #{{{
    """Return the stimulus code that drives the DUT inputs.

    Every input in the global ``sig_list`` except ``clk`` and ``rst_n``
    ends up in one of four sections:
      * valid signals  - randomised while not asserted or when the ready
                         partner is high
      * ready signals  - randomised every clock
      * data signals   - randomised according to their matched valid (and
                         ready) signal
      * everything else - assigned once in an ``initial`` block
    """
    always_head = (
        "always @(posedge clk or negedge rst_n)begin\n"
        "    if(~rst_n)begin\n"
    )
    fold_end = "//-------------------------------------}}}\n\n"

    valid_blk = ["//-------------------------------------{{{valid sig assign\n"]
    ready_blk = ["//-------------------------------------{{{ready sig assign\n"]
    data_blk  = ["//-------------------------------------{{{data  sig assign\n"]
    other_blk = ["//-------------------------------------{{{other sig assign\n",
                 "initial begin\n"]

    for sig in sig_list:
        if sig.port != "input" or sig.name in ("clk", "rst_n"):
            continue
        name = sig.name

        if sig.valid_en() == 1:
            valid_blk += [
                always_head,
                "        %s <= 0;\n" % name,
                "    end\n",
                "    else if(%s || ~%s)begin\n" % (sig.ready_sig, name),
                "        %s <= $urandom;\n" % name,
                "    end\n",
                "end\n",
                "\n",
            ]
        elif sig.ready_en() == 1:
            ready_blk += [
                always_head,
                "        %s <= 0;\n" % name,
                "    end\n",
                "    else begin\n",
                "        %s <= $urandom;\n" % name,
                "    end\n",
                "end\n",
                "\n",
            ]
        elif sig.find_valid_sig == 1 and sig.find_ready_sig == 1:
            data_blk += [
                always_head,
                "        %s <= 'x;\n" % name,
                "    end\n",
                "    else if(%s && %s)begin\n" % (sig.valid_sig, sig.ready_sig),
                "        %s <= $urandom;\n" % name,
                "    end\n",
                "    else if(%s == 0)begin\n" % (sig.valid_sig),
                "        %s <= $urandom;\n" % name,
                "    end\n",
                "end\n",
                "\n",
            ]
        elif sig.find_valid_sig == 1:
            data_blk += [
                always_head,
                "        %s <= 'x;\n" % name,
                "    end\n",
                "    else if(%s == 0)begin\n" % (sig.valid_sig),
                "        %s <= $urandom;\n" % name,
                "    end\n",
                "end\n",
                "\n",
            ]
        else:
            other_blk.append("    %s = $urandom;\n" % name)

    valid_blk.append(fold_end)
    ready_blk.append(fold_end)
    data_blk.append(fold_end)
    other_blk += ["    `DELAY(50, clk);\n", "end\n", "\n", fold_end]

    return "".join(valid_blk + ready_blk + data_blk + other_blk)
pass #}}}

def gen_top(top, para_dec, sig_dec, force_dec, inst, ver_proc): #{{{
    """Assemble the complete ``testbench`` module source.

    Args:
        top: top module name (used for the ``<top>_pkg`` import).
        para_dec, sig_dec, force_dec, inst: pre-generated text sections,
            inserted in that order.
        ver_proc: verification tasks; only inserted, together with the
            package import, when the global ``ver`` flag is set.

    Returns:
        The testbench source as one string (no trailing newline).

    The clock generator assigns ``clk`` an initial value before toggling
    it: the signals are declared as ``logic``, which starts out as X, and
    inverting X yields X again, so an uninitialised clock would never
    produce an edge and the reset sequence would never complete.
    """
    parts = ['''`define DELAY(N, clk) begin \\
\trepeat(N) @(posedge clk);\\
\t#1ps;\\
end
''']
    if ver:
        parts.append("\nimport %s_pkg::*;\n" % top)

    parts.append('''
module testbench();

//-------------------------------------{{{common cfg
timeunit 1ns;
timeprecision 1ps;
initial $timeformat(-9,3,"ns",6);

string tc_name;
int tc_seed;

initial begin
    if(!$value$plusargs("tc_name=%s", tc_name)) $error("no tc_name!");
    else $display("tc name = %0s", tc_name);
    if(!$value$plusargs("ntb_random_seed=%0d", tc_seed)) $error("no tc_seed");
    else $display("tc seed = %0d", tc_seed);
end
//-------------------------------------}}}

''')
    parts.append(para_dec)
    parts.append("\n")
    parts.append(sig_dec)

    parts.append('''
//-------------------------------------{{{clk/rst cfg
initial begin
    clk = 1'b0;
    forever #5ns clk = ~clk;
end
initial begin
    rst_n = 1'b0;
\t`DELAY(30, clk);
\trst_n = 1'b1;
end
initial begin
    #100000ns $finish;
end
//-------------------------------------}}}

''')
    parts.append(force_dec)
    parts.append(inst)
    if ver:
        parts.append(ver_proc)
    parts.append("endmodule")
    return "".join(parts)

pass #}}}

def gen_filelist(top, path, pre=""): #{{{
    """Return the contents of the simulator file list (tb.f).

    Uses the globals ``pwd_dict`` (library search directory, extended by
    ``pre``), ``pwd_path`` (the RTL path) and ``ver`` (adds the package
    file when set). ``path`` is accepted but not used here.
    """
    entries = [
        "+libext+.v+.sv\n",
        "-y %s" % pwd_dict + pre + "\n",
        "\n",
    ]
    if ver:
        entries.append("../ver/%s_pkg.sv\n" % top)
    entries.extend([pwd_path + "\n", "../top/testbench.sv"])
    return "".join(entries)
pass #}}}

def gen_verification(top, path, testbench, lst, pkg): #{{{
    """Create the ``<top>_verification`` working directory.

    An existing directory is first moved to ``<top>_verification_bak``
    (replacing any older backup). The ``vcs_demo`` template found next to
    this script (``sys.path[0]``) is then copied in, and the generated
    testbench, file list and package are written over it. Finally a short
    how-to is printed. ``path`` is accepted but not used here.
    """
    root_path = sys.path[0]

    if os.path.exists("%s_verification" % top):
        # Keep exactly one backup generation.
        for backup_cmd in (
            "rm -rf %s_verification_bak" % (top),
            "mv -f %s_verification %s_verification_bak" % (top, top),
        ):
            os.system(backup_cmd)

    for setup_cmd in (
        "mkdir -p %s_verification" % top,
        "cp %s/vcs_demo/* ./%s_verification -rf" %(root_path, top),
        "chmod a+x ./%s_verification/cfg/check_fail.pl" % top,
    ):
        os.system(setup_cmd)

    generated = (
        ("./%s_verification/top/testbench.sv" % top, testbench),
        ("./%s_verification/cfg/tb.f" % top, lst),
        ("./%s_verification/ver/%s_pkg.sv" % (top, top), pkg),
    )
    for target, text in generated:
        with open (target, "w") as hd:
            hd.write(text)

    banner = "##====================================================================##"
    for message in (
        banner,
        "Gen over! please cd ./%s_verification/sim" % top,
        "You need modify ./%s_verification/top/testbench.sv" % top,
        "    like cp ./%s_verification_bak/top/testbench.sv ./%s_verification/top/" % (top, top),
        "You need modify ./%s_verification/cfg/tb.f" % top,
        "    like cp ./%s_verification_bak/cfg/tb.f ./%s_verification/cfg/" % (top, top),
        banner,
    ):
        print(message)
#}}}

def similar_diff_ratio(str1, str2): #{{{
    """Return the difflib similarity ratio (0.0 .. 1.0) of the two strings."""
    matcher = difflib.SequenceMatcher(None, str1, str2)
    return matcher.ratio()
pass #}}}

def find_valid_ready(): #{{{
    """Pair every valid signal with its most similar ready signal.

    Only ready signals of the opposite port direction are candidates, and
    a candidate is accepted only if its name similarity exceeds 0.4 and
    beats every earlier candidate. Before searching, the conventional
    ready name is stored on the signal via ``org_ready_sig``.
    """
    for sig in sig_list:
        if sig.valid_en() != 1:
            continue
        sig.org_ready_sig()
        best = 0.4
        opposite = (r for r in ready_list if signal_hash[r].port != sig.port)
        for ready in opposite:
            score = similar_diff_ratio(sig.name, ready)
            if score > best:
                best = score
                sig.set_ready_sig(ready)
pass #}}}

def find_sig_valid(): #{{{
    """Attach a ready and a valid signal to every non-control signal.

    For each signal that is neither a valid nor a ready signal:
      * the ready partner is the most similar ready signal of the opposite
        port direction, provided the similarity exceeds 0.4;
      * the valid partner is the most similar valid signal of the same
        port direction, provided the similarity exceeds 0.5.
    The conventional partner names are stored first via ``org_ready_sig``
    and ``org_valid_sig``.
    """
    for sig in sig_list:
        if sig.valid_en() != 0 or sig.ready_en() != 0:
            continue

        sig.org_ready_sig()
        best = 0.4
        for ready in ready_list:
            if signal_hash[ready].port != sig.port:
                score = similar_diff_ratio(sig.name, ready)
                if score > best:
                    best = score
                    sig.set_ready_sig(ready)
                    debug_print(sig.name, ready, score)

        sig.org_valid_sig()
        best = 0.5
        for valid in valid_list:
            if signal_hash[valid].port == sig.port:
                score = similar_diff_ratio(sig.name, valid)
                if score > best:
                    best = score
                    sig.set_valid_sig(valid)
pass #}}}

def gain_valid_bus_signal():#{{{
    """Register every signal with a matched valid signal on that valid's bus.

    Signals without a matched valid signal (``find_valid_sig`` != 1) are
    left alone.
    """
    for sig in sig_list:
        if sig.find_valid_sig != 1:
            continue
        owner = signal_hash[sig.valid_sig]
        owner.valid_get_bus(sig)
#}}}

def gen_pkg(top): #{{{
    """Return the source of the ``<top>_pkg`` SystemVerilog package.

    The package holds the design parameters, the error counter / check
    enable, and for every valid signal that owns a data bus a struct type
    plus a queue of that type. An output valid additionally gets the
    reference-model queue ``rm_q``; the name of the last such valid is
    stored in the global ``check_valid``.
    """
    global check_valid
    check_valid = ""

    out = [
        "package %s_pkg;\n\n" % top,
        "    parameter ERROR_DEBUG_CNT = 5;\n",
    ]
    out.extend("    parameter %s = %s;\n" % (para, para_hash[para]) for para in para_list)
    out.append("\n")
    out.append("    int error_cnt = 0;\n")
    out.append("    bit check_en  = 0;\n\n")

    for valid_name in valid_list:
        valid = signal_hash[valid_name]
        if not valid.bus_list:
            continue
        out.append("    typedef struct{\n")
        out.extend("        bit %s %s;\n" %(member.width, member.name) for member in valid.bus_list)
        out.append("    } %s_struct;\n" % valid.name)
        if valid.port == "output":
            check_valid = valid.name
            out.append("    %s_struct rm_q[$];\n" % (valid.name))
        out.append("    %s_struct %s_bus_q[$];\n\n" % (valid.name, valid.name))

    out.append("endpackage")
    return "".join(out)
#}}}

def gen_ver_proc():#{{{
    """Return the auto-verification tasks placed inside the testbench.

    Generated tasks:
      * in_queue_gain / out_queue_gain - on each negedge of clk, capture
        the bus of every input / output valid signal into its queue when
        valid and ready are both high
      * rm_queue_gain - commented-out skeleton for the reference model
      * queue_check   - compares the queue of the global ``check_valid``
        signal against ``rm_q``
    followed by an ``initial`` block that forks all of them.
    """
    def bus_valids(port=None):
        # Valid signals that own a data bus, optionally of one direction.
        for valid_name in valid_list:
            valid = signal_hash[valid_name]
            if len(valid.bus_list) != 0 and (port is None or valid.port == port):
                yield valid

    def capture_task(task_name, port):
        # Sampling task for all bus-owning valids of the given direction.
        text = [
            "task %s();\n" % task_name,
            "  while(1)begin\n",
            "    @(negedge clk);\n",
        ]
        for valid in bus_valids(port):
            text.append("    if(%s && %s)begin\n" % (valid.name, valid.ready_sig))
            text.append("      %s_struct %s_dat;\n" % (valid.name, valid.name))
            for sig in valid.bus_list:
                text.append("      %s_dat.%s = %s;\n" % (valid.name, sig.name, sig.name))
            text.append("      %s_bus_q.push_back(%s_dat);\n" % (valid.name, valid.name))
            text.append("    end//if-end \n")
        text.append("  end//while-end \n")
        text.append("endtask: %s\n\n" % task_name)
        return text

    out = ["//-------------------------------------{{{auto_verification\n"]
    out += capture_task("in_queue_gain", "input")
    out += capture_task("out_queue_gain", "output")

    out.append("task rm_queue_gain();\n")
    for valid in bus_valids():
        out.append("  %s_struct %s_dat;\n" % (valid.name, valid.name))
    out.append("  //while(1)begin\n")
    for valid in bus_valids():
        if valid.port == "input":
            out.append("    //wait(%s_bus_q.size > 0);\n" % valid.name)
            out.append("    //%s_dat = %s_bus_q.pop_front();\n" % (valid.name, valid.name))
        if valid.port == "output":
            out.append("    //rm_q.push_back(%s_dat);\n" % valid.name)
    out.append("  //end\n")
    out.append("endtask: rm_queue_gain\n\n")

    out += [
        "task queue_check();\n",
        "  while(1)begin\n",
        "    %s_struct rm_data;\n" % check_valid,
        "    %s_struct dual_data;\n" % check_valid,
        "    wait(%s_bus_q.size() > 0);\n" % check_valid,
        "    dual_data = %s_bus_q.pop_front();\n" % check_valid,
        "    if(rm_q.size() == 0) begin\n",
        "      $display(\"dual_data = %0p, rm_queue.size = 0\", dual_data);\n",
        "      error_cnt += 1;\n",
        "    end\n",
        "    else begin\n",
        "      rm_data = rm_q.pop_front();\n",
        "      if(dual_data != rm_data)begin\n",
        "        error_cnt += 1;\n",
        "        $display(\"dual_data(%0p) != rm_data(%0p) at %t\", dual_data, rm_data, $realtime);\n",
        "      end\n",
        "      else begin\n",
        "        //$display(\"dual_data(%0p) == rm_data(%0p) at %t\", dual_data, rm_data, $realtime);\n",
        "      end\n",
        "    end\n",
        "    if(error_cnt >= ERROR_DEBUG_CNT) begin\n",
        "      $display(\"Check Error!!!\");\n",
        "      $finish;\n",
        "    end\n",
        "  end\n",
        "endtask: queue_check\n\n",
    ]

    out += [
        "initial begin\n",
        "  fork\n",
        "    in_queue_gain();\n",
        "    out_queue_gain();\n",
        "    rm_queue_gain();\n",
        "    if(check_en == 1) queue_check();\n",
        "  join_none\n",
        "end\n\n",
    ]
    out.append("//-------------------------------------}}}\n")
    return "".join(out)
#}}}

def main(): #{{{
    """Run the whole flow: parse arguments and RTL, then generate the output."""
    top, path = input_args_proc()
    rtl_file = find_top_rtl(top, path)
    sys_rtl_sig_para(top, rtl_file)

    # Analyse the signals and work out how they relate to each other.
    for analyse in (find_valid_ready, find_sig_valid, gain_valid_bus_signal):
        analyse()

    # Package (must run before gen_ver_proc: it sets check_valid).
    pkg = gen_pkg(top)

    # Testbench sections.
    sections = {
        "inst": gen_inst(top),
        "para_dec": gen_para_dec(),
        "sig_dec": gen_sig_dec(),
        "force_dec": gen_force_dec(),
        "ver_proc": gen_ver_proc(),
    }
    testbench = gen_top(top, **sections)

    # File list.
    lst = gen_filelist(top, path)

    # Write out the verification environment.
    gen_verification(top, path, testbench, lst, pkg)

pass #}}}

if __name__ == "__main__":
    main()
